#ifndef YOLOV2_H_
#define YOLOV2_H_

#include "DefaultDefine.h"
#include <string>
#include <vector>

namespace hiych {

// YOLOv2 detector built on the HiSilicon NNIE sample helpers
// (SAMPLE_COMM_SVP_NNIE_* / SAMPLE_SVP_NNIE_*).
//
// Usage as shown by this header: construct with the model path and the number
// of classes, call loadModel() once, then call getResult() for each image.
// The destructor releases what loadModel() set up via yolov2Deinit().
//
// The object owns an MMZ buffer and a loaded model that the destructor frees,
// so it must not be copied: a copy would free the same resources twice.
class Yolov2
{
public:
    // path:     model file handed to SAMPLE_COMM_SVP_NNIE_LoadModel().
    // numClass: number of object classes, copied into the post-processing
    //           parameters (u32ClassNum).
    Yolov2(std::string path, size_t numClass);
    ~Yolov2();

public:
    // Loads the model and initialises the NNIE / post-processing parameters.
    // Returns true when the model was loaded successfully.
    bool loadModel();

    // Runs the network on img and returns the detections produced by
    // _NNIE_Detection_PrintResult(); ResultThresh is forwarded to that helper.
    std::vector<DetectObjInfo> getResult(IVE_IMAGE_S& img, const float ResultThresh = 0.35f);

private:
    // Initialises g_stYolov2NnieParam and g_stYolov2SoftwareParam from the loaded model.
    void setYolov2Parameters();

    // Fills the YOLOv2 post-processing parameters and allocates their buffers.
    void setYolov2SoftwareParameters(
        SAMPLE_SVP_NNIE_CFG_S* pstCfg,
        SAMPLE_SVP_NNIE_PARAM_S *pstNnieParam,
        SAMPLE_SVP_NNIE_YOLOV2_SOFTWARE_PARAM_S* pstSoftWareParam
    );

    // Releases NNIE parameters, post-processing buffers and the model.
    void yolov2Deinit();

    // Frees the buffer allocated by setYolov2SoftwareParameters().
    void yolov2SoftwareDeinit(SAMPLE_SVP_NNIE_YOLOV2_SOFTWARE_PARAM_S* pstSoftWareParam);

private:
    // Non-copyable: declared private and intentionally left undefined so that
    // accidental copies fail at compile/link time instead of double-freeing.
    Yolov2(const Yolov2&);
    Yolov2& operator=(const Yolov2&);

private:
    SAMPLE_SVP_NNIE_MODEL_S g_stYolov2Model;                          // loaded model
    SAMPLE_SVP_NNIE_PARAM_S g_stYolov2NnieParam;                      // NNIE runtime parameters
    SAMPLE_SVP_NNIE_YOLOV2_SOFTWARE_PARAM_S g_stYolov2SoftwareParam;  // post-processing parameters and buffers

    SAMPLE_SVP_NNIE_CFG_S self;  // NNIE configuration (not a self-reference, despite the name)

    std::string path;   // model file path
    size_t numClass;    // number of object classes
};


// Stores the model path and class count and zero-fills every NNIE structure.
// No model is loaded here; that happens in loadModel().
//
// Declared inline because the definition lives in a header: without it, every
// translation unit including this header would emit its own external
// definition and the link would fail with duplicate symbols.
inline Yolov2::Yolov2(std::string path, size_t numClass):
    path(path),
    numClass(numClass)
{
    // Start from an all-zero state so the deinit path sees zero addresses
    // for anything that was never allocated.
    memset_s(&self, sizeof(self), 0, sizeof(self));
    memset_s(&g_stYolov2Model, sizeof(g_stYolov2Model), 0, sizeof(g_stYolov2Model));
    memset_s(&g_stYolov2NnieParam, sizeof(g_stYolov2NnieParam), 0, sizeof(g_stYolov2NnieParam));
    memset_s(&g_stYolov2SoftwareParam, sizeof(g_stYolov2SoftwareParam), 0, sizeof(g_stYolov2SoftwareParam));

    self.pszPic = NULL;                    // no picture file is configured
    self.u32MaxInputNum = 1;               // max input image num in each batch
    self.u32MaxRoiNum = 0;
    self.aenNnieCoreId[0] = SVP_NNIE_ID_0; // set NNIE core
}

// Releases everything the instance set up by delegating to yolov2Deinit().
// Declared inline because the definition lives in a header; a non-inline
// definition would produce duplicate symbols once two translation units
// include this file.
inline Yolov2::~Yolov2()
{
    yolov2Deinit();
}

bool Yolov2::loadModel()
{
    HI_S32 ret;
    ret = SAMPLE_COMM_SVP_NNIE_LoadModel(path.c_str(), &g_stYolov2Model);
    this->setYolov2Parameters();

    SAMPLE_PRT("model.base={ type=%x, frmNum=%u, chnNum=%u, w=%u, h=%u, stride=%u }\n",
        g_stYolov2NnieParam.astSegData[0].astSrc[0].enType,
        g_stYolov2NnieParam.astSegData[0].astSrc[0].u32Num,
        g_stYolov2NnieParam.astSegData[0].astSrc[0].unShape.stWhc.u32Chn,
        g_stYolov2NnieParam.astSegData[0].astSrc[0].unShape.stWhc.u32Width,
        g_stYolov2NnieParam.astSegData[0].astSrc[0].unShape.stWhc.u32Height,
        g_stYolov2NnieParam.astSegData[0].astSrc[0].u32Stride);
    SAMPLE_PRT("model.soft={ class=%u, ori.w=%u, ori.h=%u, bnum=%u, \
        grid.w=%u, grid.h=%u, nmsThresh=%u, confThresh=%u, u32MaxRoiNum=%u }\n",
        g_stYolov2SoftwareParam.u32ClassNum,
        g_stYolov2SoftwareParam.u32OriImWidth,
        g_stYolov2SoftwareParam.u32OriImHeight,
        g_stYolov2SoftwareParam.u32BboxNumEachGrid,
        g_stYolov2SoftwareParam.u32GridNumWidth,
        g_stYolov2SoftwareParam.u32GridNumHeight,
        g_stYolov2SoftwareParam.u32NmsThresh,
        g_stYolov2SoftwareParam.u32ConfThresh,
        g_stYolov2SoftwareParam.u32MaxRoiNum);

    return ret == HI_SUCCESS;
}

void Yolov2::setYolov2Parameters()
{
    g_stYolov2NnieParam.pstModel = &g_stYolov2Model.stModel;

    SAMPLE_COMM_SVP_NNIE_ParamInit(&self, &g_stYolov2NnieParam);
    this->setYolov2SoftwareParameters(&self, &g_stYolov2NnieParam, &g_stYolov2SoftwareParam);
}

// Fills the YOLOv2 post-processing parameters and allocates the single cached
// buffer that backs the temporary work area and the three result blobs.
//
// pstCfg           - unused; kept so the signature matches the declaration.
// pstNnieParam     - initialised NNIE parameters; the input width/height are
//                    read from segment 0, source node 0.
// pstSoftWareParam - structure to fill.
//
// Buffer layout (offsets from the allocation base):
//   [ tmp buffer | DstRoi | DstScore | ClassRoiNum ]
//
// If the allocation fails the function logs and returns early, leaving the
// buffer addresses in pstSoftWareParam untouched instead of publishing
// addresses derived from a NULL base.
//
// Declared inline because the definition lives in a header (avoids duplicate
// symbols when the header is included from several translation units).
inline void Yolov2::setYolov2SoftwareParameters(
    SAMPLE_SVP_NNIE_CFG_S* pstCfg,
    SAMPLE_SVP_NNIE_PARAM_S *pstNnieParam,
    SAMPLE_SVP_NNIE_YOLOV2_SOFTWARE_PARAM_S* pstSoftWareParam
)
{
    (void)pstCfg; // not needed by the current implementation

    // Bias values written to af32Bias[0..9].
    // Presumably the YOLOv2 anchor priors as (w, h) pairs - TODO confirm
    // against the training configuration of the model.
    static const float kBias[] = {
        0.52f, 0.61f,
        1.05f, 1.12f,
        1.85f, 2.05f,
        4.63f, 4.49f,
        7.15f, 7.56f
    };
    const HI_U32 u32BiasCount = sizeof(kBias) / sizeof(kBias[0]);

    // Input geometry comes from the network's first source node.
    pstSoftWareParam->u32OriImHeight = pstNnieParam->astSegData[0].astSrc[0].unShape.stWhc.u32Height;
    pstSoftWareParam->u32OriImWidth = pstNnieParam->astSegData[0].astSrc[0].unShape.stWhc.u32Width;

    // Hard-coded values; they have to agree with the model that is loaded.
    pstSoftWareParam->u32BboxNumEachGrid = 5;   // boxes predicted per grid cell
    pstSoftWareParam->u32ClassNum = static_cast<HI_U32>(this->numClass); // class count from the constructor
    pstSoftWareParam->u32GridNumHeight = 12;    // output grid rows
    pstSoftWareParam->u32GridNumWidth = 20;     // output grid columns
    pstSoftWareParam->u32NmsThresh = static_cast<HI_U32>(0.3f * SAMPLE_SVP_NNIE_QUANT_BASE);
    pstSoftWareParam->u32ConfThresh = static_cast<HI_U32>(0.25f * SAMPLE_SVP_NNIE_QUANT_BASE);
    pstSoftWareParam->u32MaxRoiNum = 10;
    for (HI_U32 i = 0; i < u32BiasCount; ++i) {
        pstSoftWareParam->af32Bias[i] = kBias[i];
    }

    /* Sizes of the assist buffer sections */
    // One extra slot on top of the configured classes
    // (presumably a background class - TODO confirm).
    const HI_U32 u32ClassNum = pstSoftWareParam->u32ClassNum + 1;
    const HI_U32 u32BboxNum = pstSoftWareParam->u32BboxNumEachGrid * pstSoftWareParam->u32GridNumHeight *
        pstSoftWareParam->u32GridNumWidth;
    const HI_U32 u32TmpBufTotalSize = SAMPLE_SVP_NNIE_Yolov2_GetResultTmpBuf(pstSoftWareParam);
    const HI_U32 u32DstRoiSize = SAMPLE_SVP_NNIE_ALIGN16(u32ClassNum * u32BboxNum * sizeof(HI_U32) * SAMPLE_SVP_NNIE_COORDI_NUM);
    const HI_U32 u32DstScoreSize = SAMPLE_SVP_NNIE_ALIGN16(u32ClassNum * u32BboxNum * sizeof(HI_U32));
    const HI_U32 u32ClassRoiNumSize = SAMPLE_SVP_NNIE_ALIGN16(u32ClassNum * sizeof(HI_U32));
    const HI_U32 u32TotalSize = u32DstRoiSize + u32DstScoreSize + u32ClassRoiNumSize + u32TmpBufTotalSize;

    /* Malloc assist buffer memory */
    HI_U64 u64PhyAddr = 0;
    HI_U8* pu8VirAddr = NULL;
    const HI_S32 s32Ret = SAMPLE_COMM_SVP_MallocCached("SAMPLE_YOLOV2_INIT", NULL, (HI_U64*)&u64PhyAddr,
        (void**)&pu8VirAddr, u32TotalSize);
    if (s32Ret != HI_SUCCESS || pu8VirAddr == NULL) {
        // Without this check the code below would clear/flush a NULL buffer
        // and store offset-only addresses in the result blobs.
        SAMPLE_PRT("yolov2 assist buffer malloc failed, size=%u, ret=0x%x\n",
            u32TotalSize, static_cast<unsigned int>(s32Ret));
        return;
    }

    memset_s(pu8VirAddr, u32TotalSize, 0, u32TotalSize);

    SAMPLE_COMM_SVP_FlushCache(u64PhyAddr, (void*)pu8VirAddr, u32TotalSize);

    // Section offsets inside the allocation.
    const HI_UL ulVirBase = (HI_UL)pu8VirAddr;
    const HI_U32 u32RoiOffset = u32TmpBufTotalSize;
    const HI_U32 u32ScoreOffset = u32RoiOffset + u32DstRoiSize;
    const HI_U32 u32ClassRoiNumOffset = u32ScoreOffset + u32DstScoreSize;

    /* set each tmp buffer addr */
    pstSoftWareParam->stGetResultTmpBuf.u64PhyAddr = u64PhyAddr;
    pstSoftWareParam->stGetResultTmpBuf.u64VirAddr = (HI_U64)ulVirBase;

    /* set result blob: ROI coordinates */
    pstSoftWareParam->stDstRoi.enType = SVP_BLOB_TYPE_S32;
    pstSoftWareParam->stDstRoi.u64PhyAddr = u64PhyAddr + u32RoiOffset;
    pstSoftWareParam->stDstRoi.u64VirAddr = (HI_U64)(ulVirBase + u32RoiOffset);
    pstSoftWareParam->stDstRoi.u32Stride = u32DstRoiSize; // same aligned size as the section
    pstSoftWareParam->stDstRoi.u32Num = 1;
    pstSoftWareParam->stDstRoi.unShape.stWhc.u32Chn = 1;
    pstSoftWareParam->stDstRoi.unShape.stWhc.u32Height = 1;
    pstSoftWareParam->stDstRoi.unShape.stWhc.u32Width = u32ClassNum *
        u32BboxNum * SAMPLE_SVP_NNIE_COORDI_NUM;

    /* set result blob: scores */
    pstSoftWareParam->stDstScore.enType = SVP_BLOB_TYPE_S32;
    pstSoftWareParam->stDstScore.u64PhyAddr = u64PhyAddr + u32ScoreOffset;
    pstSoftWareParam->stDstScore.u64VirAddr = (HI_U64)(ulVirBase + u32ScoreOffset);
    pstSoftWareParam->stDstScore.u32Stride = u32DstScoreSize; // same aligned size as the section
    pstSoftWareParam->stDstScore.u32Num = 1;
    pstSoftWareParam->stDstScore.unShape.stWhc.u32Chn = 1;
    pstSoftWareParam->stDstScore.unShape.stWhc.u32Height = 1;
    pstSoftWareParam->stDstScore.unShape.stWhc.u32Width = u32ClassNum * u32BboxNum;

    /* set result blob: per-class ROI counts */
    pstSoftWareParam->stClassRoiNum.enType = SVP_BLOB_TYPE_S32;
    pstSoftWareParam->stClassRoiNum.u64PhyAddr = u64PhyAddr + u32ClassRoiNumOffset;
    pstSoftWareParam->stClassRoiNum.u64VirAddr = (HI_U64)(ulVirBase + u32ClassRoiNumOffset);
    pstSoftWareParam->stClassRoiNum.u32Stride = u32ClassRoiNumSize; // same aligned size as the section
    pstSoftWareParam->stClassRoiNum.u32Num = 1;
    pstSoftWareParam->stClassRoiNum.unShape.stWhc.u32Chn = 1;
    pstSoftWareParam->stClassRoiNum.unShape.stWhc.u32Height = 1;
    pstSoftWareParam->stClassRoiNum.unShape.stWhc.u32Width = u32ClassNum;
}

// Runs the network on one image and returns the detections.
//
// Steps: copy img into the NNIE input (segment 0, node 0), run the forward
// pass, run the YOLOv2 post-processing, then convert the result blobs with
// _NNIE_Detection_PrintResult().
//
// img          - input image handed to FillNnieByImg().
// ResultThresh - forwarded unchanged to _NNIE_Detection_PrintResult()
//                (presumably a minimum score - TODO confirm in that helper).
//
// NOTE(review): the results of FillNnieByImg(), SAMPLE_SVP_NNIE_Forward() and
// SAMPLE_SVP_NNIE_Yolov2_GetResult() are not inspected. Their return types
// are not visible in this file; if they report a status it should be checked
// here and an empty vector returned on failure.
//
// Declared inline because the definition lives in a header (avoids duplicate
// symbols when the header is included from several translation units).
inline std::vector<DetectObjInfo> Yolov2::getResult(IVE_IMAGE_S& img, const float ResultThresh)
{
    SAMPLE_SVP_NNIE_INPUT_DATA_INDEX_S stInputDataIdx = {0};
    SAMPLE_SVP_NNIE_PROCESS_SEG_INDEX_S stProcSegIdx = {0};

    // Fill src data: the input comes from img, not from a picture file.
    self.pszPic = NULL;
    stInputDataIdx.u32SegIdx = 0;
    stInputDataIdx.u32NodeIdx = 0;

    FillNnieByImg(&self, &g_stYolov2NnieParam, 0, 0, &img);

    SAMPLE_SVP_NNIE_Forward(&g_stYolov2NnieParam, &stInputDataIdx, &stProcSegIdx, HI_TRUE);

    SAMPLE_SVP_NNIE_Yolov2_GetResult(&g_stYolov2NnieParam, &g_stYolov2SoftwareParam);

    return _NNIE_Detection_PrintResult(
        &g_stYolov2SoftwareParam.stDstScore,
        &g_stYolov2SoftwareParam.stDstRoi,
        &g_stYolov2SoftwareParam.stClassRoiNum,
        ResultThresh
    );
}

// Tears down everything owned by this instance, in this order:
// NNIE parameters, post-processing buffers, the loaded model, and finally
// SAMPLE_COMM_SVP_CheckSysExit().
//
// NOTE(review): SAMPLE_COMM_SVP_CheckSysExit() looks like a process-wide SDK
// shutdown rather than per-instance cleanup - confirm before creating more
// than one Yolov2 object in the same process.
//
// Declared inline because the definition lives in a header (avoids duplicate
// symbols when the header is included from several translation units).
inline void Yolov2::yolov2Deinit()
{
    SAMPLE_COMM_SVP_NNIE_ParamDeinit(&g_stYolov2NnieParam);

    yolov2SoftwareDeinit(&g_stYolov2SoftwareParam);

    SAMPLE_COMM_SVP_NNIE_UnloadModel(&g_stYolov2Model);

    SAMPLE_COMM_SVP_CheckSysExit();
}

// Frees the post-processing buffer and clears every address that pointed
// into it.
//
// The temporary buffer and the three result blobs (DstRoi, DstScore,
// ClassRoiNum) share one allocation whose base is stored in
// stGetResultTmpBuf, so only that base is freed; the blob addresses are just
// reset to zero. Does nothing when pstSoftWareParam is NULL or when no buffer
// is recorded, which makes repeated calls harmless.
//
// Declared inline because the definition lives in a header (avoids duplicate
// symbols when the header is included from several translation units).
inline void Yolov2::yolov2SoftwareDeinit(SAMPLE_SVP_NNIE_YOLOV2_SOFTWARE_PARAM_S* pstSoftWareParam)
{
    if (pstSoftWareParam == NULL) {
        return;
    }

    // Nothing was allocated (or it was already released).
    if (pstSoftWareParam->stGetResultTmpBuf.u64PhyAddr == 0 || pstSoftWareParam->stGetResultTmpBuf.u64VirAddr == 0) {
        return;
    }

    SAMPLE_SVP_MMZ_FREE(pstSoftWareParam->stGetResultTmpBuf.u64PhyAddr,
        pstSoftWareParam->stGetResultTmpBuf.u64VirAddr);
    pstSoftWareParam->stGetResultTmpBuf.u64PhyAddr = 0;
    pstSoftWareParam->stGetResultTmpBuf.u64VirAddr = 0;

    // These addresses were offsets into the block that was just freed.
    pstSoftWareParam->stDstRoi.u64PhyAddr = 0;
    pstSoftWareParam->stDstRoi.u64VirAddr = 0;
    pstSoftWareParam->stDstScore.u64PhyAddr = 0;
    pstSoftWareParam->stDstScore.u64VirAddr = 0;
    pstSoftWareParam->stClassRoiNum.u64PhyAddr = 0;
    pstSoftWareParam->stClassRoiNum.u64VirAddr = 0;
}

}

#endif